Support for WMMA instructions for RDNA4 GPUs#929
Conversation
This commit adds Wave Matrix Multiply Accumulate (WMMA) instruction support for AMD's RDNA4 architecture GPUs (gfx1200+). Changes: - Add WMMA_RDNA4 module in src/device/gcn/wmma_rdna4.jl - Support for new RDNA4 WMMA intrinsics with _gfx12 suffix - Simplified VGPR layout (no data duplication, 8 elements per thread) - Support for Float16 and BFloat16 types (FP8 types ready for future addition) - Add comprehensive tests in test/wmma_rdna4_tests.jl - Update documentation with RDNA4 section and examples - Update existing WMMA tests to also detect RDNA4 Architectural Differences from RDNA3: - Each lane handles 8 elements (vs 16 with duplication in RDNA3) - New intrinsic names with _gfx12 suffix and explicit vector type annotations - Subtarget feature: wmma-128b-insts (vs gfx11-insts for RDNA3) - Cleaner VGPR distribution with no data duplication References: - AMD GPUOpen: https://gpuopen.com/learn/using_matrix_core_amd_rdna4/ - LLVM commit: llvm/llvm-project@829afc4 Generated by Mistral Vibe. Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
Added tile pointer and stride helper functions for WMMA_RDNA4.
Updated example code block to use Julia syntax highlighting.
|
My last commits fail the CI but if I look at the logs there are no errors actually, I don't know if it's related to the updates to the buildkite. |
There was a problem hiding this comment.
AMDGPU.jl Benchmarks
Details
| Benchmark suite | Current: 9ee1ce0 | Previous: 756602c | Ratio |
|---|---|---|---|
amdgpu/synchronization/context/device |
600 ns |
600 ns |
1 |
amdgpu/synchronization/stream/blocking |
260 ns |
240 ns |
1.08 |
amdgpu/synchronization/stream/nonblocking |
340 ns |
340 ns |
1 |
array/accumulate/Float32/1d |
84941 ns |
86251 ns |
0.98 |
array/accumulate/Float32/dims=1 |
383596 ns |
393845 ns |
0.97 |
array/accumulate/Float32/dims=1L |
134982 ns |
131681 ns |
1.03 |
array/accumulate/Float32/dims=2 |
130392 ns |
103022 ns |
1.27 |
array/accumulate/Float32/dims=2L |
2809690 ns |
2827930 ns |
0.99 |
array/accumulate/Int64/1d |
98541 ns |
96412 ns |
1.02 |
array/accumulate/Int64/dims=1 |
288534 ns |
285244 ns |
1.01 |
array/accumulate/Int64/dims=1L |
167452 ns |
160812 ns |
1.04 |
array/accumulate/Int64/dims=2 |
123992 ns |
120772 ns |
1.03 |
array/accumulate/Int64/dims=2L |
2983033 ns |
3014433 ns |
0.99 |
array/broadcast |
133412 ns |
128932 ns |
1.03 |
array/construct |
1680 ns |
1680 ns |
1 |
array/copy |
38771 ns |
39371 ns |
0.98 |
array/copyto!/cpu_to_gpu |
183313 ns |
114832 ns |
1.60 |
array/copyto!/gpu_to_cpu |
183493 ns |
152432 ns |
1.20 |
array/copyto!/gpu_to_gpu |
128142 ns |
88321 ns |
1.45 |
array/iteration/findall/bool |
179452 ns |
181912 ns |
0.99 |
array/iteration/findall/int |
189393 ns |
190933 ns |
0.99 |
array/iteration/findfirst/bool |
124201 ns |
114451 ns |
1.09 |
array/iteration/findfirst/int |
114811 ns |
116331 ns |
0.99 |
array/iteration/findmin/1d |
169482 ns |
166203 ns |
1.02 |
array/iteration/findmin/2d |
155612 ns |
156173 ns |
1.00 |
array/iteration/logical |
353385 ns |
346025 ns |
1.02 |
array/iteration/scalar |
295995 ns |
289864 ns |
1.02 |
array/permutedims/2d |
74041 ns |
64761 ns |
1.14 |
array/permutedims/3d |
74331 ns |
73791 ns |
1.01 |
array/permutedims/4d |
77391 ns |
76481 ns |
1.01 |
array/random/rand/Float32 |
54581 ns |
51540 ns |
1.06 |
array/random/rand/Int64 |
57501 ns |
56210 ns |
1.02 |
array/random/rand!/Float32 |
146032 ns |
142162 ns |
1.03 |
array/random/rand!/Int64 |
147413 ns |
141832 ns |
1.04 |
array/random/randn/Float32 |
99331 ns |
86921 ns |
1.14 |
array/random/randn!/Float32 |
87222 ns |
152202 ns |
0.57 |
array/reductions/mapreduce/Float32/1d |
130992 ns |
132902 ns |
0.99 |
array/reductions/mapreduce/Float32/dims=1 |
93392 ns |
95052 ns |
0.98 |
array/reductions/mapreduce/Float32/dims=1L |
774481 ns |
777081 ns |
1.00 |
array/reductions/mapreduce/Float32/dims=2 |
97692 ns |
96731 ns |
1.01 |
array/reductions/mapreduce/Float32/dims=2L |
297145 ns |
299584 ns |
0.99 |
array/reductions/mapreduce/Int64/1d |
134432 ns |
133322 ns |
1.01 |
array/reductions/mapreduce/Int64/dims=1 |
95691 ns |
78081 ns |
1.23 |
array/reductions/mapreduce/Int64/dims=1L |
782341 ns |
783471 ns |
1.00 |
array/reductions/mapreduce/Int64/dims=2 |
96462 ns |
96252 ns |
1.00 |
array/reductions/mapreduce/Int64/dims=2L |
303244 ns |
308254 ns |
0.98 |
array/reductions/reduce/Float32/1d |
133771 ns |
132802 ns |
1.01 |
array/reductions/reduce/Float32/dims=1 |
95121 ns |
94832 ns |
1.00 |
array/reductions/reduce/Float32/dims=1L |
773901 ns |
774621 ns |
1.00 |
array/reductions/reduce/Float32/dims=2 |
97152 ns |
96802 ns |
1.00 |
array/reductions/reduce/Float32/dims=2L |
296735 ns |
307245 ns |
0.97 |
array/reductions/reduce/Int64/1d |
134502 ns |
129672 ns |
1.04 |
array/reductions/reduce/Int64/dims=1 |
95331 ns |
78151 ns |
1.22 |
array/reductions/reduce/Int64/dims=1L |
782231 ns |
781931 ns |
1.00 |
array/reductions/reduce/Int64/dims=2 |
96621 ns |
96192 ns |
1.00 |
array/reductions/reduce/Int64/dims=2L |
296015 ns |
298414 ns |
0.99 |
array/reverse/1d |
44641 ns |
44380 ns |
1.01 |
array/reverse/1dL |
75261 ns |
74131 ns |
1.02 |
array/reverse/1dL_inplace |
127212 ns |
108282 ns |
1.17 |
array/reverse/1d_inplace |
78901 ns |
86471 ns |
0.91 |
array/reverse/2d |
51100 ns |
50661 ns |
1.01 |
array/reverse/2dL |
101772 ns |
100341 ns |
1.01 |
array/reverse/2dL_inplace |
135532 ns |
117622 ns |
1.15 |
array/reverse/2d_inplace |
79061 ns |
95391 ns |
0.83 |
array/sorting/1d |
340625 ns |
341945 ns |
1.00 |
integration/byval/reference |
39621 ns |
38830 ns |
1.02 |
integration/byval/slices=1 |
40880 ns |
40880 ns |
1 |
integration/byval/slices=2 |
145702 ns |
158462 ns |
0.92 |
integration/byval/slices=3 |
237713 ns |
238013 ns |
1.00 |
integration/volumerhs |
5034992 ns |
4942659 ns |
1.02 |
kernel/indexing |
104961 ns |
43630 ns |
2.41 |
kernel/indexing_checked |
131991 ns |
128022 ns |
1.03 |
kernel/launch |
1340 ns |
1290 ns |
1.04 |
kernel/rand |
203133 ns |
106671 ns |
1.90 |
latency/import |
1486594790 ns |
1501349912 ns |
0.99 |
latency/precompile |
11980399330 ns |
12041117438 ns |
0.99 |
latency/ttfp |
10904391209 ns |
10491950084 ns |
1.04 |
This comment was automatically generated by workflow using github-action-benchmark.
|
Yes, we had issues with the runners. I relaunched the failed jobs. Would the current status be ready for review or are you still working on it? |
|
If it passes the CI I'm happy to get it reviewed! |
|
The addition looks good to me. I was curious if you had specific idea in mind on how to achieve
|
|
I haven't given it much thought, but something along the lines of a trait, where we have empty structs RDNA4 and RDNA3 as subtypes of a |
|
Thanks, feel free to open an issue reporting the above thoughts so we could have a trace for future todos. One last thing, would it make sense to explicitly have |
|
It makes sense and I have changed that. I just hope it doesn't break anything downstream for the people who were referring to |
|
@pxl-th is it fine for you to rename the WMMA -> WMMA_3 for consistency with 4? |
As per title. Most things work in the same way as RDNA3, except for the fact that you don't need data duplication as RDNA4 lanes have 8 elements and not 16 (8*2 as it was before). I kept both implementation and tests separate from the RDNA3 version, but maybe in the future they could be kind of merged with a runtime dispatch based on the hardware.
DISCLAIMER: I did use Mistral vibe to get the first draft and then took development from there. I do not know if there is a specific policy against use of AI for pull requests to this project, thus I understand if you're not willing to look at it. The tests pass on my RX 9070XT and even if you diff with the RDNA3 version you'll see that the logic is the same and what changes is the shape of the fragments and addressing in the lanes.